from collections import Counter
import pybedtools
from matplotlib import patches
from pylab import *
from numpy import *



# Peak categories, as found in the feature column of the peak annotation.
# Peaks in a "skip" category are ignored; only gene-associated ("keep")
# categories are analysed.  Any category in neither set is an error.
categories_skip = {
    "antisense",
    "prompt",
    "antisense_distal",
    "antisense_distal_upstream",
    "roadmap_dyadic",
    "roadmap_enhancer",
    "FANTOM5_enhancer",
    "novel_enhancer_CAGE",
    "novel_enhancer_HiSeq",
    "other",
}
categories_keep = {
    "sense_proximal",
    "sense_upstream",
    "sense_distal",
    "sense_distal_upstream",
}


# Value of the "ppvalue" GFF attribute for every retained peak, keyed by
# "<chrom>_<start>-<end>_<strand>" (the same naming used in the first column
# of the expression table read further below).
ppvalues = {}

filename = "peaks.gff"
print("Reading %s" % filename)
lines = pybedtools.BedTool(filename)
for line in lines:
    # The third GFF column (feature type) carries the peak category.
    category = line.fields[2]
    if category in categories_skip:
        continue
    if category not in categories_keep:
        # Fail loudly on a category that neither set anticipates, rather than
        # silently including or dropping the peak.
        raise ValueError("Unknown category %s" % category)
    name = "%s_%d-%d_%s" % (line.chrom, line.start, line.end, line.strand)
    ppvalues[name] = float(line.attrs['ppvalue'])

# Read the per-library expression table and accumulate, for every peak kept
# in `ppvalues`, the total expression over all HiSeq libraries and over all
# CAGE libraries.
filename = "peaks.expression.txt"
print("Reading", filename)

# Exact column layout the table must have: the peak name, followed by the
# HiSeq libraries, followed by the CAGE libraries.
hiseq_columns = [
    'HiSeq_t00_r1', 'HiSeq_t00_r2', 'HiSeq_t00_r3',
    'HiSeq_t01_r1', 'HiSeq_t01_r2',
    'HiSeq_t04_r1', 'HiSeq_t04_r2', 'HiSeq_t04_r3',
    'HiSeq_t12_r1', 'HiSeq_t12_r2', 'HiSeq_t12_r3',
    'HiSeq_t24_r1', 'HiSeq_t24_r2', 'HiSeq_t24_r3',
    'HiSeq_t96_r1', 'HiSeq_t96_r2', 'HiSeq_t96_r3',
]
cage_columns = [
    'CAGE_00_hr_A', 'CAGE_00_hr_C', 'CAGE_00_hr_G', 'CAGE_00_hr_H',
    'CAGE_01_hr_A', 'CAGE_01_hr_C', 'CAGE_01_hr_G',
    'CAGE_04_hr_C', 'CAGE_04_hr_E',
    'CAGE_12_hr_A', 'CAGE_12_hr_C',
    'CAGE_24_hr_C', 'CAGE_24_hr_E',
    'CAGE_96_hr_A', 'CAGE_96_hr_C', 'CAGE_96_hr_E',
]
expected_header = ['peak'] + hiseq_columns + cage_columns
# Index of the first CAGE column; HiSeq values occupy columns 1..cage_start-1.
cage_start = 1 + len(hiseq_columns)

hiseq = {}
cage = {}
# The context manager guarantees the file is closed even if validation fails.
with open(filename) as handle:
    # An empty file yields an empty header, which is reported as a mismatch
    # below instead of surfacing as a bare StopIteration.
    line = next(handle, "")
    words = line.split()
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for index, expected in enumerate(expected_header):
        found = words[index] if index < len(words) else None
        if found != expected:
            raise ValueError(
                "%s: expected column %d of the header to be %r, found %r"
                % (filename, index, expected, found)
            )
    for line in handle:
        words = line.split()
        if len(words) != len(expected_header):
            raise ValueError(
                "%s: expected %d columns, found %d in line %r"
                % (filename, len(expected_header), len(words), line)
            )
        name = words[0]
        if name not in ppvalues:
            # Peak was filtered out (or never seen) when reading the GFF file.
            continue
        hiseq[name] = sum([float(word) for word in words[1:cage_start]])
        cage[name] = sum([float(word) for word in words[cage_start:]])

# Classify every retained peak along four axes and tally the combinations:
#   * summed HiSeq expression above / not above 1e-20,
#   * summed CAGE expression above / not above 1e-20,
#   * ppvalue above +threshold (labelled "hiseq_significant"),
#   * ppvalue below -threshold (labelled "cage_significant").
# NOTE(review): the threshold is -log10(0.05), so ppvalue is presumably a
# signed -log10 p-value -- confirm against whatever produces peaks.gff.
ppvalue_threshold = -log10(0.05)
counts = Counter()
for name in ppvalues:
    peak_hiseq = hiseq[name]
    peak_cage = cage[name]
    peak_ppvalue = ppvalues[name]
    key = (
        "hiseq_expressed" if peak_hiseq > 1e-20 else "hiseq_not_expressed",
        "cage_expressed" if peak_cage > 1e-20 else "cage_not_expressed",
        "hiseq_significant" if peak_ppvalue > +ppvalue_threshold else "hiseq_not_significant",
        "cage_significant" if peak_ppvalue < -ppvalue_threshold else "cage_not_significant",
    )
    counts[key] += 1

# Report the size of every combination that actually occurred.
for key, tally in counts.items():
    print(key, tally)

# Set up the figure and draw the four translucent circles of the diagram:
# two large overlapping circles for "expressed" (HiSeq in red on the left,
# CAGE in blue on the right) and a smaller circle inside each for
# "significant".
f = figure()
ax = subplot(111)

circle_hiseq_expressed = Circle((-1., 0), 3, color='r', alpha=0.2)
circle_cage_expressed = Circle((+1., 0), 3, color='b', alpha=0.2)
circle_hiseq_significant = Circle((-2, 0), 1.75, color='r', alpha=0.2)
circle_cage_significant = Circle((+2, 0), 1.75, color='b', alpha=0.2)

# Patches are added in a fixed order so the drawing order stays the same.
for circle in (
    circle_hiseq_expressed,
    circle_cage_expressed,
    circle_hiseq_significant,
    circle_cage_significant,
):
    ax.add_patch(circle)

# Equal scaling on both axes keeps the circles round.
axis('square')
xlim(-5, 5)
ylim(-5, 5)


# Each region of the diagram: where its count is written, and the
# classification key in `counts` that the region corresponds to.
venn_regions = [
    # expressed in both, significant in neither
    ((0, -2), ('hiseq_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_not_significant')),
    # expressed in CAGE only, not significant
    ((+1.95, -2.25), ('hiseq_not_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_not_significant')),
    # expressed in HiSeq only, not significant
    ((-1.95, -2.25), ('hiseq_expressed', 'cage_not_expressed', 'hiseq_not_significant', 'cage_not_significant')),
    # expressed in both, significant for HiSeq
    ((-1.2, 0), ('hiseq_expressed', 'cage_expressed', 'hiseq_significant', 'cage_not_significant')),
    # expressed in both, significant for CAGE
    ((+1.2, 0), ('hiseq_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_significant')),
    # expressed in HiSeq only, significant for HiSeq
    ((-2.9, 0), ('hiseq_expressed', 'cage_not_expressed', 'hiseq_significant', 'cage_not_significant')),
    # expressed in CAGE only, significant for CAGE
    ((+2.9, 0), ('hiseq_not_expressed', 'cage_expressed', 'hiseq_not_significant', 'cage_significant')),
]

# Counter lookups return 0 for combinations that never occurred.
number1, number2, number3, number4, number5, number6, number7 = [
    counts[region_key] for _, region_key in venn_regions
]

for (label_x, label_y), region_key in venn_regions:
    text(
        label_x,
        label_y,
        str(counts[region_key]),
        horizontalalignment='center',
        verticalalignment='center',
    )

# Captions above the two halves of the diagram.
for label_x, caption, caption_color in ((-2, 'HiSeq', 'red'), (+2, 'CAGE', 'blue')):
    text(
        label_x,
        3,
        caption,
        color=caption_color,
        horizontalalignment='center',
        verticalalignment='bottom',
    )

# Legend entries are stand-alone patches; they are not the circles drawn in
# the axes.  The "significant" entries use a higher alpha than the
# "expressed" ones.
hiseq_expressed_patch = patches.Patch(
    color='red',
    alpha=0.2,
    label='Expressed as\nshort capped RNAs',
)
cage_expressed_patch = patches.Patch(
    color='blue',
    alpha=0.2,
    label='Expressed as\nlong capped RNAs',
)
hiseq_significant_patch = patches.Patch(
    color='red',
    alpha=0.4,
    label='Significantly enriched\nin short capped RNA\n(single-end) libraries',
)
cage_significant_patch = patches.Patch(
    color='blue',
    alpha=0.4,
    label='Significantly enriched\nin long capped RNA\n(CAGE) libraries',
)

legend_handles = [
    hiseq_expressed_patch,
    hiseq_significant_patch,
    cage_expressed_patch,
    cage_significant_patch,
]
legend(
    handles=legend_handles,
    fontsize=8,
    loc='lower center',
    ncol=2,
    bbox_to_anchor=(0.5, -0.05),
)

title("Gene-associated peaks")

# Hide the axes frame and ticks; only the diagram itself should be visible.
axis('off')

# Write the current figure as SVG first and then as PNG.
for filename in ("figure_peak_venn_diagram.svg", "figure_peak_venn_diagram.png"):
    print("Saving figure to", filename)
    savefig(filename)

# Text summary of the totals, derived from the per-region counts shown in the
# diagram (number1..number7) and the number of retained peaks.
summary_lines = [
    ("Total number of gene-associated peaks: %d", len(ppvalues)),
    ("Expressed in HiSeq and CAGE: %d", number1 + number4 + number5),
    ("Expressed in HiSeq only: %d", number3 + number6),
    ("Expressed in CAGE only: %d", number2 + number7),
    ("Significantly higher expression in HiSeq: %d", number4 + number6),
    ("Significantly higher expression in CAGE: %d", number5 + number7),
]
for template, value in summary_lines:
    print(template % value)
